from torch.utils.data import DataLoader,TensorDataset
from torch.utils.data import random_split
from torchdata.datapipes.map import MapDataPipe
from pytorch_lightning import LightningDataModule
from game24_utils import *
import warnings
import sys
from util import *
from bw_utils import *
sys.path.append("gpt-plan-benchmark/gpt_plan_test")
warnings.filterwarnings("ignore", ".*does not have many workers.*")
import yaml
import json
from tarski.io import PDDLReader
import pandas as pd
import torch
import random

def get_problem(instance, domain):
    """Parse a PDDL domain and an instance with a single reader.

    Args:
        instance: Argument forwarded to ``PDDLReader.parse_instance``.
        domain: Argument forwarded to ``PDDLReader.parse_domain``.

    Returns:
        Whatever ``PDDLReader.parse_instance`` returns for ``instance``.
    """
    pddl_reader = PDDLReader(raise_on_error=True)
    # The domain is parsed first, then the instance, on the same reader.
    pddl_reader.parse_domain(domain)
    parsed_problem = pddl_reader.parse_instance(instance)
    return parsed_problem


class InputExample():
    """Plain record holding one example's input ids, labels, attention mask and reward."""

    def __init__(self,input_ids, labels, attention_masks, reward):
        # Each constructor argument is stored unchanged under the same name.
        self.input_ids, self.labels = input_ids, labels
        self.attention_masks, self.reward = attention_masks, reward


class Game24DataModule(LightningDataModule):
    """Lightning data module for the Game of 24 task.

    For the ``fit`` stage it tokenizes the JSON-lines file ``args.train_data``
    into a padded tensor dataset and takes two slices of the ``Puzzles``
    column of ``data/24.csv`` as validation puzzles. For the ``test`` stage it
    exposes rows 910-1009 of the same column.
    """

    def __init__(
            self,
            args,
            tokenizer,
            train_size = 0.8,
            device = "cuda",
            limit_prompts=None,
    ):
        """Store the configuration; no data is loaded until ``setup``.

        Args:
            args: Options object; this class reads ``train_data``,
                ``batch_size``, ``do_sft``, ``do_cot`` and ``do_rej`` from it.
            tokenizer: Callable tokenizer that also provides ``decode`` and
                ``eos_token_id``.
            train_size: Stored on the instance; not used by the code in this
                class.
            device: Stored on the instance; not used by the code in this class.
            limit_prompts: Accepted for signature compatibility; currently
                ignored.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.args = args
        self.train_data = None
        self.val_data = None
        self.train_size = train_size
        self.test_data = None
        self.device = device

    def setup(self,stage=None):
        """Build the datasets for the requested stage.

        Args:
            stage: ``"fit"`` or ``None`` builds ``train_data`` and
                ``val_data``; ``"test"`` builds ``test_data``; any other value
                leaves the datasets untouched.

        Raises:
            ValueError: If no training example survives the filtering done in
                ``convert_train_to_features``.
        """
        print(stage)
        if stage=="fit" or stage is None:
            print("Loading data")
            with open(self.args.train_data,'r') as f:
                # One JSON object per line; blank lines are skipped so a
                # trailing newline does not abort loading.
                game24full = [json.loads(line) for line in f if line.strip()]
            features = self.convert_train_to_features(game24full)
            if not features:
                raise ValueError(
                    f"No training examples were produced from {self.args.train_data!r}"
                )
            all_inputs_ids = torch.stack([feat.input_ids for feat in features])
            all_labels = torch.stack([feat.labels for feat in features])
            all_attention_masks = torch.stack([feat.attention_masks for feat in features])
            all_rewards = torch.Tensor([feat.reward for feat in features])
            train_data = TensorDataset(all_inputs_ids, all_labels, all_attention_masks, all_rewards)
            self.train_data = PromptDataPipe(train_data)
            tests_all = list(pd.read_csv('data/24.csv')['Puzzles'])
            vald = tests_all[10:20] + tests_all[-20:-10]
            self.val_data = PromptDataPipe(vald)
        elif stage=='test':
            tests_all = list(pd.read_csv('data/24.csv')['Puzzles'])
            self.test_data = PromptDataPipe(tests_all[910:1010])

    def creat_labels(self,inputs,generate_text,ignore_token_id):
        """Return a copy of the input ids with the leading prompt part masked.

        The number of trailing positions left unmasked equals the number of
        ids obtained by tokenizing ``generate_text`` on its own (including any
        tokens the tokenizer adds by itself).

        Args:
            inputs: Tokenizer output whose ``"input_ids"`` is a 2-D tensor.
            generate_text: Text whose token count decides how many trailing
                positions stay unmasked.
            ignore_token_id: Value written into the masked positions.

        Returns:
            A tensor shaped like ``inputs["input_ids"]``.
        """
        labels = inputs["input_ids"].clone()
        num_target_tokens = len(self.tokenizer(generate_text)["input_ids"])
        prompt_length = labels.shape[1] - num_target_tokens
        labels[:, :prompt_length] = ignore_token_id
        return labels

    def convert_train_to_features(self, examples,max_length=1024):
        """Tokenize and right-pad training examples to ``max_length``.

        When ``args.do_sft`` or ``args.do_rej`` is set, only examples with
        ``reward == 100`` are kept, and an example whose ``idx`` equals the
        ``idx`` of the most recently kept example is dropped.

        Args:
            examples: Iterable of dicts with ``input``, ``reward`` and either
                ``answer`` (plain SFT) or ``generate_data`` (all other modes);
                ``idx`` is also required when filtering is active.
            max_length: Length every example is padded to.

        Returns:
            A list of ``InputExample`` objects holding 1-D tensors of length
            ``max_length`` and the example's reward.

        Raises:
            ValueError: If a tokenized prompt is longer than ``max_length``.
        """
        # LabelSmoother is expected to come from one of the module's star imports.
        ignore_token_id = LabelSmoother.ignore_index
        features = []
        num_printed = 0
        query = -1
        for example in examples:
            if self.args.do_sft or self.args.do_rej:
                if example['reward'] != 100:
                    continue
                if example['idx'] == query:
                    continue
                query = example['idx']

            if self.args.do_sft and not self.args.do_cot:
                generate_text = example['answer'] + " = 24"
                input_prompt = standard_sft_prompt_wrap(example['input']) + generate_text
            else:
                # The cot, rejection-sampling and default modes share one format.
                input_prompt = cot_prompt_wrap(example['input']) + example['generate_data']
                generate_text = "Steps: \n" + example['generate_data']

            inputs = self.tokenizer(input_prompt,return_tensors="pt")
            input_length = inputs["input_ids"].shape[1]
            if max_length < input_length:
                # Fail fast instead of waiting on stdin and then crashing on a
                # negative padding size.
                raise ValueError(
                    f"Input length {input_length} is greater than max_length {max_length}"
                )
            padding_length = max_length - input_length
            labels = self.creat_labels(inputs,generate_text,ignore_token_id)
            attention_mask = torch.ones_like(inputs["input_ids"])
            pad_ids = torch.full((1, padding_length), self.tokenizer.eos_token_id, dtype=torch.long)
            pad_mask = torch.zeros((1, padding_length), dtype=torch.long)
            padded_input_ids = torch.cat([inputs['input_ids'], pad_ids], dim=-1)[0]
            padded_attention_mask = torch.cat([attention_mask, pad_mask], dim=-1)[0]
            # NOTE(review): label padding uses eos_token_id rather than
            # ignore_token_id, so padded positions are not marked as ignored in
            # the labels. Confirm this is what the training step expects.
            padded_labels = torch.cat([labels, pad_ids], dim=-1)[0]
            features.append(InputExample(
                input_ids=padded_input_ids,
                labels=padded_labels,
                attention_masks=padded_attention_mask,
                reward=example['reward']
            ))
            # Show the first few converted examples as a sanity check.
            if num_printed < 5:
                print("***Example: ***", num_printed)
                print("lenngth of input_ids:", len(inputs["input_ids"][0]))
                print("padded_input_tokens:", self.tokenizer.decode(padded_input_ids))
                num_printed += 1
        return features

    def train_dataloader(self):
        """Return a shuffled loader over ``train_data`` with ``args.batch_size``."""
        return DataLoader(self.train_data,batch_size=self.args.batch_size,shuffle=True)

    def val_dataloader(self):
        """Return a loader over ``val_data`` with batch size 1."""
        return DataLoader(self.val_data, batch_size=1)

    def test_dataloader(self):
        """Return a loader over ``test_data`` with batch size 1."""
        return DataLoader(self.test_data, batch_size=1)

class PromptDataModule(LightningDataModule):
    """Lightning data module for the Blocksworld task.

    ``__init__`` loads the Blocksworld config and prompt templates from
    ``data/blocksworld``. ``setup`` reads the instances for ``args.step``,
    converts each to ``[INIT, GOAL, PLAN]`` and splits them, in file order,
    into a training part and a validation part.
    """

    def __init__(
        self,
        args,
        tokenizer,
        train_size=0.8,
        limit_prompts=None,
    ):
        """Load the static Blocksworld resources and store the configuration.

        Args:
            args: Options object; ``setup`` reads ``step`` from it.
            tokenizer: Stored on the instance; excluded from the saved
                hyperparameters.
            train_size: Fraction of the examples placed in the training split.
            limit_prompts: If not ``None``, only the first ``limit_prompts``
                examples are kept before splitting.
        """
        super().__init__()
        self.save_hyperparameters(ignore="tokenizer")
        # Read the config once but parse it twice, so that ``data`` and
        # ``config`` remain two independent objects as before.
        with open('data/blocksworld/bw_config.yaml', 'r') as file:
            config_text = file.read()
        self.data = yaml.safe_load(config_text)
        self.config = yaml.safe_load(config_text)
        with open("data/blocksworld/my_mcts_prompts_update.json", 'r') as file:
            self.prompts = json.load(file)
        self.domain_pddl = f'gpt-plan-benchmark/gpt_plan_test/instances/{self.config["domain_file"]}'
        self.base_prompt = "I am playing with a set of blocks where I need to arrange the blocks into stacks. Here are the actions I can do\n\nPick up a block\nUnstack a block from on top of another block\nPut down a block\nStack a block on top of another block\n\nI have the following restrictions on my actions:\nI can only pick up or unstack one block at a time.\nI can only pick up or unstack a block if my hand is empty.\nI can only pick up a block if the block is on the table and the block is clear. A block is clear if the block has no other blocks on top of it and if the block is not picked up.\nI can only unstack a block from on top of another block if the block I am unstacking was really on top of the other block.\nI can only unstack a block from on top of another block if the block I am unstacking is clear.\nOnce I pick up or unstack a block, I am holding the block.\nI can only put down a block that I am holding.\nI can only stack a block on top of another block if I am holding the block being stacked.\nI can only stack a block on top of another block if the block onto which I am stacking the block is clear.\nOnce I put down or stack a block, my hand becomes empty.\nAfter being given an initial state and an action, give the new state after performing the action.\n"
        self.tokenizer = tokenizer
        self.args = args
        self.train_data = None
        self.val_data = None

    def setup(self, stage=None):
        """Load the instances for ``args.step`` and build both splits.

        Args:
            stage: Accepted for the Lightning interface; the same data is
                built whatever its value.
        """
        # Resolved relative to the working directory, like the other
        # ``data/blocksworld`` resources read in ``__init__``.
        with open(f"data/blocksworld/step_{self.args.step}.json", 'r') as file:
            train_data = json.load(file)

        all_data = []
        for d in train_data:
            # d[0] is passed to the PDDL reader as the instance and d[1] is
            # used as the ground-truth plan text.
            problem = get_problem(d[0], self.domain_pddl)
            gt_plan_text = d[1]
            INIT, GOAL, PLAN = instance_to_text_blocksworld(problem, True, gt_plan_text, self.data)
            all_data.append([INIT, GOAL, PLAN])
        if self.hparams.limit_prompts is not None:
            all_data = all_data[: self.hparams.limit_prompts]
        num_train = int(len(all_data) * self.hparams.train_size)
        self.train_data = PromptDataPipe(all_data[:num_train])
        self.val_data = PromptDataPipe(all_data[num_train:])

    def train_dataloader(self):
        """Return a shuffled loader over ``train_data`` with batch size 1."""
        return DataLoader(self.train_data, shuffle=True, batch_size=1)

    def val_dataloader(self):
        """Return a loader over ``val_data`` with batch size 1."""
        return DataLoader(self.val_data, batch_size=1)


class PromptDataPipe(MapDataPipe):
    """Map-style data pipe that wraps an indexable collection of problems."""

    def __init__(self, problems) -> None:
        """Keep a reference to ``problems``; no copy is made."""
        super().__init__()
        self.problems = problems

    def __len__(self):
        """Return the length of the wrapped collection."""
        size = len(self.problems)
        return size

    def __getitem__(self, index):
        """Return the entry of the wrapped collection at ``index``."""
        item = self.problems[index]
        return item